# frozen_string_literal: true

require 'numo/optimize'

require 'rumale/base/classifier'
require 'rumale/probabilistic_output'
require 'rumale/validation'

require_relative 'base_estimator'

module Rumale
  module LinearModel
    # SVC is a class that implements Support Vector Classifier with the squared hinge loss.
    # For multiclass classification problem, it uses one-vs-the-rest strategy.
    #
    # @note
    #   Rumale::SVM provides linear support vector classifier based on LIBLINEAR.
    #   If you prefer execution speed, you should use Rumale::SVM::LinearSVC.
    #   https://github.com/yoshoku/rumale-svm
    #
    # @example
    #   require 'rumale/linear_model/svc'
    #
    #   estimator =
    #     Rumale::LinearModel::SVC.new(reg_param: 1.0)
    #   estimator.fit(training_samples, training_labels)
    #   results = estimator.predict(testing_samples)
    class SVC < Rumale::LinearModel::BaseEstimator
      include Rumale::Base::Classifier

      # Class labels seen during fitting, stored in ascending order.
      # @return [Numo::Int32] (shape: [n_classes])
      attr_reader :classes

      # Build a linear Support Vector Classifier trained with the squared hinge loss.
      #
      # @param reg_param [Float] The regularization parameter (weight of the L2 penalty in the objective).
      # @param fit_bias [Boolean] The flag indicating whether to fit the bias term.
      # @param bias_scale [Float] The scale of the bias term.
      # @param max_iter [Integer] The maximum number of iterations handed to the optimizer.
      # @param tol [Float] The tolerance of loss for terminating optimization.
      # @param probability [Boolean] The flag indicating whether to fit the sigmoid used for probability estimation.
      # @param n_jobs [Integer] The number of jobs for running the fit and predict methods in parallel.
      #   If nil is given, the methods do not execute in parallel.
      #   If zero or less is given, it becomes equal to the number of processors.
      #   This parameter is ignored if the Parallel gem is not loaded.
      # @param verbose [Boolean] The flag indicating whether to output loss during iteration.
      #   'iterate.dat' file is generated by numo-optimize.
      def initialize(reg_param: 1.0, fit_bias: true, bias_scale: 1.0, max_iter: 1000, tol: 1e-4, probability: false,
                     n_jobs: nil, verbose: false)
        super()
        @params = {
          reg_param: reg_param, fit_bias: fit_bias, bias_scale: bias_scale,
          max_iter: max_iter, tol: tol, probability: probability,
          n_jobs: n_jobs, verbose: verbose
        }
      end

      # Learn the model from the given training data.
      # A two-class problem trains a single binary model; otherwise one binary model
      # is trained per class (one-vs-the-rest).
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [SVC] The learned classifier itself.
      def fit(x, y)
        x = Rumale::Validation.check_convert_sample_array(x)
        y = Rumale::Validation.check_convert_label_array(y)
        Rumale::Validation.check_sample_size(x, y)

        sorted_labels = y.to_a.uniq.sort
        @classes = Numo::Int32[*sorted_labels]
        x = expand_feature(x) if fit_bias?

        unless multiclass_problem?
          # The smallest label is treated as the negative class (-1), everything else as +1.
          binary_targets = Numo::Int32.cast(y.ne(@classes[0])) * 2 - 1
          @weight_vec, @bias_term, @prob_param = partial_fit(x, binary_targets)
          return self
        end

        n_classes = @classes.size
        # x already carries the extra bias column when fit_bias? is true; it is not part of the weights.
        n_features = fit_bias? ? x.shape[1] - 1 : x.shape[1]
        @weight_vec = Numo::DFloat.zeros(n_classes, n_features)
        @bias_term = Numo::DFloat.zeros(n_classes)
        @prob_param = Numo::DFloat.zeros(n_classes, 2)

        # Trains the n-th one-vs-the-rest model: samples of class n are +1, all others -1.
        train_one = proc do |n|
          ovr_targets = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
          partial_fit(x, ovr_targets)
        end
        models = enable_parallel? ? parallel_map(n_classes, &train_one) : Array.new(n_classes, &train_one)

        models.each_with_index do |(weights, bias, sigmoid_params), n|
          @weight_vec[n, true] = weights
          @bias_term[n] = bias
          @prob_param[n, true] = sigmoid_params
        end

        self
      end

      # Compute confidence scores for samples as the linear projection plus the bias term.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence score per sample.
      def decision_function(x)
        samples = Rumale::Validation.check_convert_sample_array(x)
        projections = samples.dot(@weight_vec.transpose)

        projections + @bias_term
      end

      # Predict the class label of each sample.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      def predict(x)
        x = Rumale::Validation.check_convert_sample_array(x)

        n_samples = x.shape[0]
        if multiclass_problem?
          scores = decision_function(x)
          # The class whose model gives the highest score wins.
          pick_label = proc { |n| @classes[scores[n, true].max_index] }
          labels = enable_parallel? ? parallel_map(n_samples, &pick_label) : Array.new(n_samples, &pick_label)
        else
          # Non-negative scores map to index 1 (second class), negative scores to index 0.
          class_indices = decision_function(x).ge(0.0).to_a
          labels = class_indices.map { |index| @classes[index] }
        end
        Numo::Int32.asarray(labels)
      end

      # Predict the probability of each class for the given samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
      def predict_proba(x)
        x = Rumale::Validation.check_convert_sample_array(x)

        unless multiclass_problem?
          exponent = @prob_param[0] * decision_function(x) + @prob_param[1]
          positive = 1.0 / (Numo::NMath.exp(exponent) + 1.0)
          probs = Numo::DFloat.zeros(x.shape[0], 2)
          probs[true, 1] = positive
          probs[true, 0] = 1.0 - positive
          return probs
        end

        exponent = @prob_param[true, 0] * decision_function(x) + @prob_param[true, 1]
        probs = 1.0 / (Numo::NMath.exp(exponent) + 1.0)
        # Normalize each row so the per-class probabilities sum to one.
        row_sums = probs.sum(axis: 1)
        (probs.transpose / row_sums).transpose.dup
      end

      private

      # Fit one binary model by minimizing the L2-regularized squared hinge loss.
      #
      # @param base_x [Numo::DFloat] The samples as prepared by #fit (including the bias column when fit_bias? is true).
      # @param bin_y [Numo::Int32] The binary targets encoded as -1 / +1.
      # @return [Array] The weight vector, the bias term, and the sigmoid parameters, in that order.
      def partial_fit(base_x, bin_y)
        # Returns the loss together with its gradient with respect to the weights.
        objective = proc do |weights, samples, targets, penalty|
          n_samples = samples.shape[0]
          scores = samples.dot(weights)
          margins = 1 - targets * scores
          hinge = samples.class.maximum(0, margins)
          loss = 0.5 * penalty * weights.dot(weights) + (hinge**2).sum.fdiv(n_samples)
          gradient = penalty * weights
          # Only samples with a positive margin violation contribute to the data term.
          active = margins.gt(0)
          unless active.count.zero?
            residuals = scores[active] - targets[active]
            gradient += 2.fdiv(n_samples) * residuals.dot(samples[active, true])
          end
          [loss, gradient]
        end

        initial_weights = Numo::DFloat.zeros(base_x.shape[1])
        result = Numo::Optimize.minimize(
          fnc: objective,
          jcb: true,
          x_init: initial_weights,
          args: [base_x, bin_y, @params[:reg_param]],
          maxiter: @params[:max_iter],
          factr: @params[:tol] / Numo::Optimize::Lbfgsb::DBL_EPSILON,
          verbose: @params[:verbose] ? 1 : -1
        )
        solution = result[:x]

        # Without probability estimation, fixed parameters [1, 0] are stored instead of fitted ones.
        sigmoid_params = if @params[:probability]
                           Rumale::ProbabilisticOutput.fit_sigmoid(base_x.dot(solution), bin_y)
                         else
                           Numo::DFloat[1, 0]
                         end
        weights, bias = split_weight(solution)

        [weights, bias, sigmoid_params]
      end

      # Tell whether more than two classes were seen during fitting.
      def multiclass_problem?
        @classes.size > 2
      end
    end
  end
end
